{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 构建 Relay 模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "直接构建："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "import tvm.testing\n",
    "from tvm import relay\n",
    "from tvm.target.target import Target\n",
    "from tvm.relay.backend import Runtime, Executor, graph_executor_codegen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add(shape, dtype):\n",
    "    \"\"\"Build an IRModule whose function adds two Relay variables, A and B.\n",
    "\n",
    "    Args:\n",
    "        shape: Shape given to both input variables.\n",
    "        dtype: Data type given to both input variables.\n",
    "\n",
    "    Returns:\n",
    "        The module produced by ``tvm.IRModule.from_expr`` for the function.\n",
    "    \"\"\"\n",
    "    operands = tuple(\n",
    "        relay.var(name, shape=shape, dtype=dtype) for name in (\"A\", \"B\")\n",
    "    )\n",
    "    func = relay.Function(operands, relay.add(*operands))\n",
    "    return tvm.IRModule.from_expr(func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>A: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">8</span>), float32], <span style=\"color: #AA22FF; font-weight: bold\">%</span>B: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">8</span>), float32]) {\n",
       "  add(<span style=\"color: #AA22FF; font-weight: bold\">%</span>A, <span style=\"color: #AA22FF; font-weight: bold\">%</span>B)\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Instantiate the adder for a (1, 8) float32 input and display the module.\n",
    "mod = add(shape=(1, 8), dtype=\"float32\")\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Canonicalize the LLVM target; the host half of the pair is not used below.\n",
    "target, target_host = tvm.target.Target.canon_target_and_host(\n",
    "    tvm.target.Target(\"llvm\")\n",
    ")\n",
    "# Optimize the module for the target; only the first returned value is kept.\n",
    "mod, _ = relay.optimize(mod, target)\n",
    "# Run graph-executor codegen and keep the lowered functions it returns.\n",
    "grc = graph_executor_codegen.GraphExecutorCodegen(None, target)\n",
    "_, lowered_funcs, _ = grc.codegen(mod, mod[\"main\"])\n",
    "# Build the lowered functions for the target; the result is discarded.\n",
    "_ = relay.backend._backend.build(lowered_funcs, target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Test to build a nn model and get schedule_record from build_module\"\"\"\n",
    "# Imported under an alias so the top-level `testing` module imported in the\n",
    "# first cell is not rebound by this cell.\n",
    "from tvm.relay import testing as relay_testing\n",
    "\n",
    "\n",
    "def check_schedule(executor):\n",
    "    \"\"\"Assert that each non-main function in the build kept its schedule record.\n",
    "\n",
    "    Args:\n",
    "        executor: Object returned by ``relay.build``; its ``function_metadata``\n",
    "            mapping is iterated.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If a checked function has no \"schedule\" attribute, or\n",
    "            its ``schedule_record`` and ``primitive_record`` differ in length.\n",
    "    \"\"\"\n",
    "    for func_name, func_meta in executor.function_metadata.items():\n",
    "        # Only converted operators are checked; names containing \"main\" are skipped.\n",
    "        if \"main\" in func_name:\n",
    "            continue\n",
    "        primfunc = list(func_meta.relay_primfuncs.values())[0]\n",
    "        # The schedule must have been stored in the function attributes.\n",
    "        assert \"schedule\" in primfunc.attrs\n",
    "        sch = primfunc.attrs[\"schedule\"]\n",
    "        assert len(sch.schedule_record) == len(sch.primitive_record)\n",
    "\n",
    "\n",
    "relay_mod, params = relay_testing.mobilenet.get_workload(batch_size=1, dtype=\"float32\")\n",
    "target_llvm = tvm.target.Target(\"llvm\")\n",
    "# Ask the build to keep the schedule record that check_schedule inspects.\n",
    "config = {\"te.keep_schedule_record\": True}\n",
    "\n",
    "with tvm.transform.PassContext(opt_level=3, config=config):\n",
    "    # Build once with the AOT executor on the cpp runtime ...\n",
    "    aot_executor_factory = relay.build(\n",
    "        relay_mod,\n",
    "        target_llvm,\n",
    "        runtime=Runtime(\"cpp\"),\n",
    "        executor=Executor(\"aot\"),\n",
    "        params=params,\n",
    "    )\n",
    "    # ... and once with relay.build's default executor and runtime.\n",
    "    graph_executor_factory = relay.build(\n",
    "        relay_mod,\n",
    "        target_llvm,\n",
    "        params=params,\n",
    "    )\n",
    "\n",
    "check_schedule(aot_executor_factory)\n",
    "check_schedule(graph_executor_factory)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py312x",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
